// Copyright 2020 The Cockroach Authors.
//
// Use of this software is governed by the Business Source License
// included in the file licenses/BSL.txt.
//
// As of the Change Date specified in that file, in accordance with
// the Business Source License, use of this software will be governed
// by the Apache License, Version 2.0, included in the file
// licenses/APL.txt.
//
// Command skip-test will skip a test in the CockroachDB repo.
package main

import (
	"context"
	"flag"
	"fmt"
	"log"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"regexp"
	"strconv"
	"strings"

	"github.com/cockroachdb/errors"
	"github.com/google/go-github/github"
	"golang.org/x/oauth2"
)

// leakTestRegexp matches a `defer leaktest.AfterTest(t)()` statement.
// replaceFile uses it to skip past such lines so that the inserted skip call
// lands after them.
var leakTestRegexp = regexp.MustCompile(`defer leaktest.AfterTest\(t\)\(\)`)

// Command-line flags. They are registered on a dedicated FlagSet named after
// the executable and configured with flag.ExitOnError; main parses it and
// installs usage as its Usage function.
var (
	flags          = flag.NewFlagSet(os.Args[0], flag.ExitOnError)
	flagIssueNum   = flags.Int("issue_num", 0, "issue to link the skip to; if unset skip-test will\ntry to search for existing issues based on the test name")
	flagReason     = flags.String("reason", "flaky test", "reason to put under skip")
	flagUnderRace  = flags.Bool("under_race", false, "if true, only skip under race")
	flagUnderBazel = flags.Bool("under_bazel", false, "if true, only skip under bazel")
)

// description is the help text that usage prints ahead of the flag defaults.
const description = `The skip-test utility creates a pull request to skip a test.

Example usage:

    ./bin/skip-test -issue_num 1234 pkg/to/test:TestToSkip

The following options are available:

`

// usage prints the tool description followed by the flag defaults and a
// trailing blank line. Everything is written to the FlagSet's output so the
// help text is not split between stdout and stderr.
func usage() {
	out := flags.Output()
	fmt.Fprint(out, description)
	flags.PrintDefaults()
	fmt.Fprintln(out)
}

// main skips a single test and opens a pull request for the change. It:
//
//  1. parses the flags and the `TestName` / `pkg/to/test:TestToSkip` argument,
//  2. creates a `skip-test-<TestName>` branch from upstream master,
//  3. locates the test with `git grep` and inserts a skip call (replaceFile),
//  4. runs `dev generate bazel`, commits, and force-pushes the branch to the
//     remote named by the `cockroach.remote` git config entry,
//  5. opens the GitHub compare page in a browser and switches back to the
//     previously checked-out branch.
//
// A failing step terminates the process through log.Fatal; git operations
// that already happened are not rolled back.
func main() {
	flags.Usage = usage
	// The FlagSet is created with flag.ExitOnError, so Parse normally exits on
	// its own when parsing fails; the check is kept as a safeguard.
	if err := flags.Parse(os.Args[1:]); err != nil {
		usage()
		log.Fatal(err)
	}

	// Exactly one positional argument (the test to skip) is accepted; both a
	// missing argument and surplus arguments are rejected here.
	if len(flags.Args()) != 1 {
		usage()
		log.Fatalf("expected exactly one argument: `TestName` or `pkg/to/test:TestToSkip`")
	}

	ctx := context.Background()

	// The error from capture is deliberately ignored: an empty result is
	// treated as "not configured" below.
	remote, _ := capture("git", "config", "--get", "cockroach.remote")
	if remote == "" {
		log.Fatalf("set cockroach.remote to the name of the Git remote to push")
	}

	// The GitHub token is optional. Without one, ghAuthClient stays nil and the
	// GitHub client is built without the OAuth transport.
	var ghAuthClient *http.Client
	ghToken, _ := capture("git", "config", "--get", "cockroach.githubToken")
	if ghToken != "" {
		ghAuthClient = oauth2.NewClient(ctx, oauth2.StaticTokenSource(&oauth2.Token{AccessToken: ghToken}))
	}
	ghClient := github.NewClient(ghAuthClient)

	// Split the argument into an optional package path and the test name.
	arg := flags.Args()[0]
	var pkgName, testName string
	splitArg := strings.Split(arg, ":")
	switch len(splitArg) {
	case 1:
		testName = splitArg[0]
	case 2:
		pkgName = splitArg[0]
		testName = splitArg[1]
	default:
		log.Fatalf("expected test to be of format `TestName` or `pkg/to/test:TestToSkip`, found %s", arg)
	}

	if *flagUnderBazel && *flagUnderRace {
		log.Fatal("cannot use both -under_race and -under_bazel")
	}

	// Check git status is clean.
	if err := spawn("git", "diff", "--exit-code"); err != nil {
		log.Fatal(errors.Wrap(err, "git state may not be clean, please use `git stash` or commit changes before proceeding."))
	}

	// Do a fresh checkout of master.
	if err := spawn("git", "fetch", "https://github.com/cockroachdb/cockroach.git", "master"); err != nil {
		log.Fatal(errors.Wrap(err, "failed to get CockroachDB master"))
	}

	skipBranch := fmt.Sprintf("skip-test-%s", testName)
	// Checkout a fresh HEAD.
	if err := spawn("git", "checkout", "-b", skipBranch, "FETCH_HEAD"); err != nil {
		log.Fatal(errors.Wrapf(err, "failed to checkout branch %s", skipBranch))
	}

	// Grep for all tests underneath the given package. Without a package the
	// whole ./pkg tree is searched.
	searchPath := pkgName
	if searchPath == "" {
		searchPath = "./pkg"
	}
	fnGrep := fmt.Sprintf(`func %s(t \*testing\.T)`, regexp.QuoteMeta(testName))
	grepOutput, err := capture(
		"git",
		"grep",
		"-i",
		fnGrep,
		searchPath,
	)
	if err != nil {
		log.Fatal(errors.Wrapf(err, "failed to grep for the failing test"))
	}

	// Exactly one matching line is required so the test's file is unambiguous.
	// NOTE(review): this presumes capture returns the output without a trailing
	// newline — confirm against capture's implementation.
	grepOutputSplit := strings.Split(grepOutput, "\n")
	if len(grepOutputSplit) != 1 {
		log.Fatalf("expected 1 result for test %s, found %d:\n%s", arg, len(grepOutputSplit), grepOutput)
	}
	// The file name is everything before the first ':' of the grep output.
	fileName := strings.Split(grepOutput, ":")[0]

	// Find an issue number.
	pkgPrefix := strings.TrimPrefix(filepath.Dir(fileName), "pkg/")
	issueNum := findIssue(ctx, ghClient, pkgPrefix, testName)

	// Replace the file with the skip status.
	replaceFile(fileName, testName, issueNum)

	// Update the package's BUILD.bazel. Prefer ./dev in the working directory
	// and fall back to a dev binary found on PATH.
	devPath, err := exec.LookPath("./dev")
	if err != nil {
		fmt.Printf("./dev not found, trying dev\n")
		devPath, err = exec.LookPath("dev")
		if err != nil {
			log.Fatal(errors.Wrapf(err, "no path found for dev"))
		}
	}
	if err := spawn(devPath, "generate", "bazel"); err != nil {
		log.Fatal(errors.Wrap(err, "failed to run bazel"))
	}

	// Commit the file. Everything under searchPath is staged, not only the
	// edited test file.
	if err := spawn("git", "add", searchPath); err != nil {
		log.Fatal(errors.Wrapf(err, "failed to add %s to commit", searchPath))
	}
	var modifierStr string
	if *flagUnderRace {
		modifierStr = " under race"
	} else if *flagUnderBazel {
		modifierStr = " under bazel"
	}
	commitMsg := fmt.Sprintf(`%s: skip %s%s

Refs: #%d

Reason: %s

Generated by bin/skip-test.

Release justification: non-production code changes
Release note: None
Epic: None
`, pkgPrefix, testName, modifierStr, issueNum, *flagReason)
	if err := spawn("git", "commit", "-m", commitMsg); err != nil {
		log.Fatal(errors.Wrapf(err, "failed to commit %s", fileName))
	}

	// Create a PR. The refspec pushes the local skip branch to a branch of the
	// same name on the remote.
	if err := spawn("git", "push", "--force", remote, fmt.Sprintf("%[1]s:%[1]s", skipBranch)); err != nil {
		log.Fatal(errors.Wrapf(err, "failed to push to remote"))
	}

	// The shorthand for opening a web browser with Python, `python -m
	// webbrowser URL`, does not set the status code appropriately.
	// TODO(mgartner): A GitHub username should be used instead of remote in the
	// URL. A remote might not be the same as the username, in which case an
	// incorrect URL is opened. For example, the remote might be named "origin"
	// but the username is "mgartner".
	if err := spawn(
		"python",
		"-c",
		"import sys, webbrowser; sys.exit(not webbrowser.open(sys.argv[1]))",
		fmt.Sprintf("https://github.com/cockroachdb/cockroach/compare/master...%s:%s?expand=1", remote, skipBranch),
	); err != nil {
		log.Fatal(errors.Wrapf(err, "failed to open web browser"))
	}

	// Go back to a previous branch.
	if err := checkoutPrevious(); err != nil {
		log.Fatal(errors.Wrapf(err, "failed to checkout previous branch"))
	}
}

// replaceFile inserts a skip call into the test function testName in fileName,
// linking it to issueNum, and then formats the file with bin/crlfmt.
//
// The call is placed on the line after the test's `func` line, or after any
// `defer leaktest.AfterTest(t)()` lines that directly follow it. The skip
// helper that is emitted depends on the -under_race and -under_bazel flags;
// the -reason flag is written as a quoted Go string literal. Any failure
// terminates the process through log.Fatal.
func replaceFile(fileName, testName string, issueNum int) {
	fileContents, err := os.ReadFile(fileName)
	if err != nil {
		log.Fatal(errors.Wrapf(err, "failed to read file: %s", fileName))
	}

	// Screw efficiency, do the dumb thing: scan line by line for the test's
	// declaration.
	lines := strings.Split(string(fileContents), "\n")
	r := regexp.MustCompile(fmt.Sprintf(`func %s\(t \*testing\.T\)`, regexp.QuoteMeta(testName)))
	lineIdx := -1
	for i, line := range lines {
		if r.MatchString(line) {
			lineIdx = i
			break
		}
	}
	if lineIdx == -1 {
		log.Fatalf("failed to find test output %s in %s", testName, fileName)
	}

	// Wait until we've passed the "leaktest" line. The bounds check guards
	// against a declaration (or leaktest line) that is the last line of the
	// file.
	insertLineIdx := lineIdx + 1
	for insertLineIdx < len(lines) && leakTestRegexp.MatchString(lines[insertLineIdx]) {
		insertLineIdx++
	}

	// Pick the skip helper matching the requested mode.
	skipFn := "skip.WithIssue"
	switch {
	case *flagUnderRace:
		skipFn = "skip.UnderRaceWithIssue"
	case *flagUnderBazel:
		skipFn = "skip.UnderBazelWithIssue"
	}
	// %q emits a properly escaped Go string literal, so a reason containing
	// quotes or backslashes still yields valid Go source.
	skipLine := fmt.Sprintf("%s(t, %d, %q)", skipFn, issueNum, *flagReason)

	newLines := make([]string, 0, len(lines)+1)
	newLines = append(newLines, lines[:insertLineIdx]...)
	newLines = append(newLines, skipLine)
	newLines = append(newLines, lines[insertLineIdx:]...)

	if err := os.WriteFile(fileName, []byte(strings.Join(newLines, "\n")), 0644); err != nil {
		log.Fatal(errors.Wrapf(err, "failed to write file %s", fileName))
	}

	// Run crlfmt on the file.
	if err := spawn("bin/crlfmt", "-w", "-tab", "2", fileName); err != nil {
		log.Fatal(errors.Wrapf(err, "failed to run crlfmt on %s", fileName))
	}
}

// checkoutPrevious switches back to the previously checked-out branch
// (`git checkout -`), but only when the current branch name starts with
// `skip-test-`. On any other branch it does nothing. It returns an error if
// the current branch cannot be determined or the checkout fails.
func checkoutPrevious() error {
	current, err := capture("git", "rev-parse", "--abbrev-ref", "HEAD")
	if err != nil {
		return errors.Newf("looking up current branch name: %w", err)
	}

	onSkipBranch := regexp.MustCompile(`^skip-test-.*`).MatchString(current)
	if onSkipBranch {
		if err := spawn("git", "checkout", "-"); err != nil {
			return errors.Newf("returning to previous branch: %w", err)
		}
	}
	return nil
}

// findIssue returns the GitHub issue number to reference in the skip.
//
// When the -issue_num flag is non-zero that value is returned without
// contacting GitHub. Otherwise it searches for open issues whose title
// contains `<pkgPrefix>: <testName> failed` and returns the number of the
// single match. A failed search, or a result count other than one, terminates
// the process through log.Fatal.
func findIssue(ctx context.Context, ghClient *github.Client, pkgPrefix, testName string) int {
	if explicit := *flagIssueNum; explicit != 0 {
		return explicit
	}

	query := fmt.Sprintf(`"%s: %s failed" in:title is:open is:issue`, pkgPrefix, testName)
	result, _, err := ghClient.Search.Issues(ctx, query, nil)
	if err != nil {
		log.Fatal(errors.Wrap(err, "failed searching for issue"))
	}

	if len(result.Issues) != 1 {
		// List whatever was found so the user can pick one with --issue_num.
		var found []string
		for _, candidate := range result.Issues {
			found = append(found, strconv.Itoa(candidate.GetNumber()))
		}
		log.Fatal(errors.Newf("found 0 or multiple issues for %s: %s\nuse --issue_num=<num> to attach a created issue.", testName, strings.Join(found, ",")))
	}
	return result.Issues[0].GetNumber()
}
